
[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True (#2677)

Merged
sudhakarsingh27 merged 29 commits into NVIDIA:main from sudhakarsingh27:fix_return_stats_max_cudnn
Mar 25, 2026

Conversation

@sudhakarsingh27 (Collaborator)

Description

cuDNN recently made it possible to return any subset of {Stats, SumExp, Max}. This PR adapts TE to always get Stats from cuDNN, plus the Max tensor when return_max_logit=True. (Note that Stats = log(SumExp) + Max.)
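As an illustrative sketch (not code from this PR), the identity Stats = log(SumExp) + Max is just the numerically stable form of log-sum-exp over a row of attention logits:

```python
import math

# Hypothetical row of attention logits x = Q @ K.T for one query position.
x = [1.5, -0.3, 2.7, 0.9]

max_x = max(x)                                       # Max
sum_exp = sum(math.exp(v - max_x) for v in x)        # SumExp (shifted for stability)

stats = math.log(sum_exp) + max_x                    # Stats = log(SumExp) + Max
logsumexp = math.log(sum(math.exp(v) for v in x))    # direct (unstable) log-sum-exp

# The two agree up to floating-point error, which is why Stats alone
# suffices for the backward pass.
assert abs(stats - logsumexp) < 1e-9
```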

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • fused_attn_f16_arbitrary_seqlen.cu
    • Removed references to the SumExp tensor, since it is no longer needed now that cuDNN returns Stats by default.
    • Set generate_stats=True, which forces cuDNN to always return the Stats tensor (needed in the backward pass).
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py
    • Removed the code that manually computed Stats = log(SumExp) + Max, since cuDNN now returns Stats directly and TE no longer needs SumExp from cuDNN
  • Corresponding documentation

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 5 commits February 12, 2026 13:12
@greptile-apps (Contributor)

greptile-apps bot commented Feb 12, 2026

Greptile Summary

This PR updates the cuDNN-based fused attention forward pass to always return Stats (= log(SumExp) + Max) directly from cuDNN, removing the prior manual Stats = log(SumExp) + Max computation in Python. When return_max_logit=True, cuDNN also returns the Max tensor. The Sum_Exp output is eliminated entirely. The descriptor field generate_max_sum_exp is renamed to return_max_logit, and the aux tensor ordering is updated throughout (Stats is now always index 0; Max is index 1 when present).

Key changes:

  • fused_attn_f16_arbitrary_seqlen.cu: generate_stats is hardcoded to true; Sum_Exp tensor removed; Stats is always set as output; aux-tensor indexing restructured so Stats is always first.
  • fused_attn.py: Removes stats = output_tensors[1] + torch.log(output_tensors[2]) computation; reads Stats and Max directly from cuDNN output at indices 1 and 2.
  • utils.h: FADescriptor_v1::generate_max_sum_exp renamed to return_max_logit.
  • fused_attn.h: Two of the four return_max_logit doc comments updated (per prior review threads, the other two at lines 403 and 556 still need updates).
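A rough sketch of the new aux-tensor ordering described above (function and variable names here are illustrative stand-ins, not the actual TE API):

```python
def unpack_aux(aux_tensors, return_max_logit):
    """Illustrative unpacking under the new ordering: Stats is always
    index 0; Max is index 1 only when return_max_logit=True; rng_state
    follows whichever tensors precede it."""
    stats = aux_tensors[0]
    if return_max_logit:
        max_t = aux_tensors[1]
        rng_state = aux_tensors[2]
    else:
        max_t = None
        rng_state = aux_tensors[1]
    return stats, max_t, rng_state

# Toy string placeholders standing in for tensors.
assert unpack_aux(["Stats", "Max", "rng"], True) == ("Stats", "Max", "rng")
assert unpack_aux(["Stats", "rng"], False) == ("Stats", None, "rng")
```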

Two SM120 (Blackwell) regressions were found in fused_attn_f16_arbitrary_seqlen.cu:

  1. The Stats stride-setting condition uses is_ragged_q && cudnn_runtime_version >= 90600 instead of use_ragged_stats, bypassing the sm_arch_ != 120 guard and calling set_ragged_offset(nullptr) on SM120.
  2. The output_S allocation shape condition also drops the sm_arch_ != 120 guard, causing a 3D/4D shape mismatch on SM120 with THD format.

Confidence Score: 2/5

  • Two SM120-specific regressions in the C++ CUDA kernel should be fixed before merging.
  • The overall design is sound and the non-SM120 path looks correct. However, two concrete bugs on SM120 (Blackwell) were introduced: (1) the Stats stride condition drops the sm_arch_ != 120 guard present in use_ragged_stats, passing a null offset_stats pointer to set_ragged_offset() during cuDNN graph construction; (2) the output_S shape allocation also drops the SM120 guard, causing a 3D/4D shape mismatch with what the graph actually builds on SM120. Since Blackwell is a supported production architecture, these are real correctness issues that should be addressed before merging.
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu — both the graph-building Stats stride condition (~line 384) and the output_S shape allocation (~line 1134) need the sm_arch_ != 120 guard restored.

Important Files Changed

Filename Overview
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu Core change: always generates Stats, removes Sum_Exp, conditionally generates Max. Two SM120-specific bugs introduced: (1) Stats stride uses is_ragged_q && cudnn_runtime_version >= 90600 instead of use_ragged_stats, causing null ragged-offset on SM120; (2) output_S shape allocation drops the sm_arch_ != 120 guard, causing a 3D/4D shape mismatch on SM120.
transformer_engine/pytorch/cpp_extensions/fused_attn.py Python-side updated to reflect new tensor ordering (Stats at index 1, Max at index 2, rng_state at index 3 when return_max_logit=True). Removes the manual log(SumExp)+Max computation. Logic and indexing look correct.
transformer_engine/common/fused_attn/utils.h Renames generate_max_sum_exp → return_max_logit in FADescriptor_v1, which is a clean semantic improvement. The std::tie comparison is updated consistently.
transformer_engine/common/include/transformer_engine/fused_attn.h Doc string updated from "produce Max and Sum_Exp, or Stats" to "produce Max along with Stats" in two locations (lines 209, 269). Two additional comment sites (lines 403, 556) noted as still needing updates per prior review threads.
transformer_engine/pytorch/csrc/extensions/attention.cpp Comment updated to reflect new aux tensor ordering (S, then Max when return_max_logit=True). Minor doc comment fix only — no logic changes.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fused_attn_arbitrary_seqlen_fwd] --> B{Aux_CTX_Tensors empty?}
    B -->|Yes - allocate| C[Allocate Stats tensor\nalways, index 0]
    C --> D{return_max_logit?}
    D -->|Yes| E[Allocate Max tensor\nindex 1]
    D -->|No| F[Skip Max]
    E --> G[Allocate rng_state\nindex 2]
    F --> G2[Allocate rng_state\nindex 1]
    B -->|No - use existing| H[Read Stats → devPtrS1\nindex 0]
    H --> I{return_max_logit?}
    I -->|Yes| J[Read Max → devPtrS2\nindex 1]
    I -->|No| K[Skip Max]
    J --> L[Read rng_state\nindex 2]
    K --> L2[Read rng_state\nindex 1]

    subgraph graph_builder [Graph Builder - fused_attn_arbitrary_seqlen_fwd_impl]
        M[sdpa generates O, Stats always] --> N[Stats: set_output=true\nset stride always]
        N --> O{return_max_logit?}
        O -->|Yes| P[Max tensor\nset_logit_max]
        O -->|No| Q[Stats only\nStats_tuple = Stats, null]
        P --> R[Stats_tuple = Stats, Max]
    end

    subgraph python [Python fused_attn_fwd - return_max_logit=True]
        S[output_tensors: out, Stats, Max, rng_state, ...]
        S --> T[aux_ctx_tensors = Stats + rng_state + optional]
        S --> U[max_tensor = output_tensors 2 = Max]
        U --> V[max_logit = amax over batch/seq dims]
        T --> W[return out, aux_ctx_tensors, max_logit]
    end

Comments Outside Diff (2)

  1. transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu, line 383-388 (link)

    Stats stride condition inconsistent with use_ragged_stats on SM120

    use_ragged_stats is defined as is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120 (line 90), deliberately excluding SM120 because cuDNN on that arch rejects interleaved strides. However, the new condition for the Stats stride is is_ragged_q && cudnn_runtime_version >= 90600 — without the sm_arch_ != 120 guard.

    On SM120 with THD format and cuDNN >= 90600: use_ragged_stats = false, so offset_stats is never initialized (it stays a null shared_ptr). But the condition here evaluates to true, calling Stats->set_ragged_offset(offset_stats) with a null pointer. This is inconsistent with the Max tensor handling below (line 360), which correctly uses use_ragged_stats; the Stats condition should be changed to use_ragged_stats to match.
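The inconsistency can be seen by evaluating both conditions as plain booleans (a sketch only; variable names follow the review comment, and the values are the hypothetical SM120 + THD + cuDNN 9.6 configuration it describes):

```python
# Hypothetical configuration: SM120 (Blackwell), THD layout, cuDNN >= 9.6.
sm_arch = 120
is_ragged_q = True
cudnn_runtime_version = 90600

# As defined at line 90 of the file under review (per the comment):
use_ragged_stats = is_ragged_q and cudnn_runtime_version >= 90600 and sm_arch != 120

# The new Stats-stride condition in this PR drops the sm_arch guard:
stats_stride_condition = is_ragged_q and cudnn_runtime_version >= 90600

# On SM120 the two disagree: offset_stats is never initialized
# (use_ragged_stats is False) yet set_ragged_offset() would still be called.
assert use_ragged_stats is False
assert stats_stride_condition is True
```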

  2. transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu, line 1132-1139 (link)

    output_S shape allocation missing sm_arch_ != 120 guard

    The old output_S allocation (now only used in the return_max_logit=false branch in the original code) had (sm_arch_ != 120) in its shape condition. The new unified allocation at this line drops that guard.

    On SM120, because b and s_q in the graph are kept as batch/max_seqlen (not replaced with token counts — see lines 100-112), Stats is built with 4D shape {batch, num_attn_heads, max_seqlen_q, 1}. But the new shape condition here — q_format == NVTE_THD && cudnn_runtime_version >= 90600 — will fire on SM120 and allocate {num_tokens_q, num_attn_heads, 1} (3D), creating a shape mismatch with the graph.

    The Max tensor (lines 1144-1145) correctly preserves the sm_arch_ != 120 check; the output_S shape condition should include the same guard.
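The shape mismatch the comment describes can be sketched the same way (sizes are hypothetical; the real shapes come from the cuDNN graph):

```python
# Hypothetical sizes for a THD batch on SM120.
batch, num_attn_heads, max_seqlen_q, num_tokens_q = 2, 8, 128, 200

q_format_is_thd = True
cudnn_runtime_version = 90600

# Graph side on SM120: b and s_q stay as batch/max_seqlen,
# so the graph builds Stats with a 4D shape.
graph_stats_shape = (batch, num_attn_heads, max_seqlen_q, 1)

# Allocation side in this PR: the condition drops the sm_arch_ != 120
# guard, so on SM120 it still picks the 3D token-packed shape.
if q_format_is_thd and cudnn_runtime_version >= 90600:
    alloc_stats_shape = (num_tokens_q, num_attn_heads, 1)
else:
    alloc_stats_shape = (batch, num_attn_heads, max_seqlen_q, 1)

# 3D vs 4D: the allocated buffer no longer matches what the graph builds.
assert len(graph_stats_shape) == 4 and len(alloc_stats_shape) == 3
assert graph_stats_shape != alloc_stats_shape
```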

Reviews (19): Last reviewed commit: "Merge branch 'fix_return_stats_max_cudnn..."

@greptile-apps (Contributor) left a comment: 5 files reviewed, 1 comment

@greptile-apps (Contributor) left a comment: 3 files reviewed, no comments

@greptile-apps (Contributor) left a comment: 3 files reviewed, no comments

@greptile-apps (Contributor) left a comment: 3 files reviewed, 1 comment

@greptile-apps (Contributor)

greptile-apps bot commented Feb 17, 2026

Additional Comments (1)

transformer_engine/pytorch/cpp_extensions/fused_attn.py
Stale docstring: wrong formula for softmaxStats

The public docstring still describes softmaxStats as log(sum(e^(x - max(x)))), which is log(SumExp). However, with this PR, the returned tensor is cuDNN's Stats = log(SumExp) + Max, not just log(SumExp). This formula was already incorrect before this PR (the old code computed Max + log(SumExp) and stored it as stats), but the PR is an opportunity to correct it.

                       softmaxStats: torch.Tensor
                           log(sum(e^(x - max(x)))) + max(x), where x=Q*K.T (i.e. Stats = log(SumExp) + Max)
                           shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32

@greptile-apps (Contributor) left a comment: 3 files reviewed, 1 comment

stats = output_tensors[1] + torch.log(output_tensors[2])
# thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Collaborator:

Do we need the "there's no typo here" :)

Collaborator (Author):

I deliberately added it because I didn't believe it and checked the shapes myself :P

@greptile-apps (Contributor) left a comment: 3 files reviewed, 1 comment

@greptile-apps (Contributor) left a comment: 3 files reviewed, no comments

@greptile-apps (Contributor) left a comment: 5 files reviewed, 1 comment

@greptile-apps (Contributor) left a comment: 5 files reviewed, no comments

@greptile-apps (Contributor) left a comment: 5 files reviewed, no comments

@sudhakarsingh27 force-pushed the fix_return_stats_max_cudnn branch from 2b64738 to e005455 (March 10, 2026 19:01)
@sudhakarsingh27 (Collaborator, Author)

/te-ci L2

cyanguwa previously approved these changes Mar 16, 2026
sudhakarsingh27 and others added 4 commits March 24, 2026 16:24
@sudhakarsingh27 sudhakarsingh27 merged commit e879bf8 into NVIDIA:main Mar 25, 2026
10 of 12 checks passed
KshitijLakhani pushed a commit that referenced this pull request Mar 25, 2026
… always and `Max` when `return_max_logit=True` (#2677)

* cudnn now returns Stats always and Max only with `return_max_logit=true`

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix a typo that caused a bug

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* update doc strings

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix more docs

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fixes from the feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* update cudnn-frontend to v1.19.1

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* update the cudnn frontend

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix a wrong omission

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
vthumbe1503 pushed a commit to ksivaman/TransformerEngine-1 that referenced this pull request Apr 1, 2026
… always and `Max` when `return_max_logit=True` (NVIDIA#2677)

(Same commit message as above.)
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

3 participants